"""
 Copyright (c) 2020 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import tensorflow as tf


def assert_shape_equal(shape_a, shape_b):
    """Asserts that shape_a and shape_b are equal.

    If every dimension of both shapes is a Python int (the shapes are fully
    static), the comparison is done immediately in Python and a ValueError is
    raised when the shapes mismatch.

    Otherwise (at least one dimension is not a Python int, e.g. a scalar
    tensor taken from tf.shape), the check is delegated to tf.assert_equal,
    which raises a tf InvalidArgumentError when the shapes mismatch.

    Args:
      shape_a: a list (or tuple) containing the shape of the first tensor.
      shape_b: a list (or tuple) containing the shape of the second tensor.

    Returns:
      A tf.no_op() when both shapes are static, or the result of
      tf.assert_equal() when at least one of them is dynamic.

    Raises:
      ValueError: When both shapes are static and unequal.
    """
    if (all(isinstance(dim, int) for dim in shape_a)
            and all(isinstance(dim, int) for dim in shape_b)):
        # Compare as lists so that a tuple and a list holding the same
        # dimensions are treated as equal: `(2, 3) != [2, 3]` is True in
        # Python, which previously produced a spurious mismatch.
        if list(shape_a) != list(shape_b):
            raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
        return tf.no_op()
    return tf.assert_equal(shape_a, shape_b)


def combined_static_and_dynamic_shape(tensor):
    """Builds a per-dimension shape list mixing static ints and dynamic sizes.

    For each axis of `tensor`, the statically known size (a Python int) is
    used when available; otherwise the corresponding element of
    tf.shape(tensor) (a scalar tensor) is used. Keeping static sizes where
    possible helps downstream ops such as reshape preserve static shape
    information.

    Args:
      tensor: A tensor of any type.

    Returns:
      A list with one entry per dimension of `tensor`; each entry is either
      an int or a scalar tensor.
    """
    static_dims = tensor.shape.as_list()
    dynamic_dims = tf.shape(input=tensor)
    # Fall back to the runtime size only for axes whose static size is unknown.
    return [
        size if size is not None else dynamic_dims[axis]
        for axis, size in enumerate(static_dims)
    ]
